{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Detailed explanation:\n",
    "https://medium.com/@alexppppp/how-to-train-a-custom-keypoint-detection-model-with-pytorch-d9af90e111da\n",
    "\n",
    "GitHub repo:\n",
    "https://github.com/alexppppp/keypoint_rcnn_training_pytorch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, json, cv2, numpy as np, matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "import torchvision\n",
    "from torchvision.models.detection.rpn import AnchorGenerator\n",
    "from torchvision.transforms import functional as F\n",
    "\n",
    "import albumentations as A # Library for augmentations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# https://github.com/pytorch/vision/tree/main/references/detection\n",
    "import transforms, utils, engine, train\n",
    "from utils import collate_fn\n",
    "from engine import train_one_epoch, evaluate"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. Augmentations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_transform():\n",
    "    return A.Compose([\n",
    "        A.Sequential([\n",
    "            A.RandomRotate90(p=1), # Random rotation of an image by 90 degrees zero or more times\n",
    "            A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, brightness_by_max=True, always_apply=False, p=1), # Random change of brightness & contrast\n",
    "        ], p=1)\n",
    "    ],\n",
    "    keypoint_params=A.KeypointParams(format='xy'), # More about keypoint formats used in albumentations library read at https://albumentations.ai/docs/getting_started/keypoints_augmentation/\n",
    "    bbox_params=A.BboxParams(format='pascal_voc', label_fields=['bboxes_labels']) # Bboxes should have labels, read more here https://albumentations.ai/docs/getting_started/bounding_boxes_augmentation/\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. Dataset class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ClassDataset(Dataset):\n",
    "    def __init__(self, root, transform=None, demo=False):                \n",
    "        self.root = root\n",
    "        self.transform = transform\n",
    "        self.demo = demo # Use demo=True if you need transformed and original images (for example, for visualization purposes)\n",
    "        self.imgs_files = sorted(os.listdir(os.path.join(root, \"images\")))\n",
    "        self.annotations_files = sorted(os.listdir(os.path.join(root, \"annotations\")))\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        img_path = os.path.join(self.root, \"images\", self.imgs_files[idx])\n",
    "        annotations_path = os.path.join(self.root, \"annotations\", self.annotations_files[idx])\n",
    "\n",
    "        img_original = cv2.imread(img_path)\n",
    "        img_original = cv2.cvtColor(img_original, cv2.COLOR_BGR2RGB)        \n",
    "        \n",
    "        with open(annotations_path) as f:\n",
    "            data = json.load(f)\n",
    "            bboxes_original = data['bboxes']\n",
    "            keypoints_original = data['keypoints']\n",
    "            \n",
    "            # All objects are glue tubes\n",
    "            bboxes_labels_original = ['Glue tube' for _ in bboxes_original]            \n",
    "\n",
    "        if self.transform:   \n",
    "            # Converting keypoints from [x,y,visibility]-format to [x, y]-format + Flattening nested list of keypoints            \n",
    "            # For example, if we have the following list of keypoints for three objects (each object has two keypoints):\n",
    "            # [[obj1_kp1, obj1_kp2], [obj2_kp1, obj2_kp2], [obj3_kp1, obj3_kp2]], where each keypoint is in [x, y]-format            \n",
    "            # Then we need to convert it to the following list:\n",
    "            # [obj1_kp1, obj1_kp2, obj2_kp1, obj2_kp2, obj3_kp1, obj3_kp2]\n",
    "            keypoints_original_flattened = [el[0:2] for kp in keypoints_original for el in kp]\n",
    "            \n",
    "            # Apply augmentations\n",
    "            transformed = self.transform(image=img_original, bboxes=bboxes_original, bboxes_labels=bboxes_labels_original, keypoints=keypoints_original_flattened)\n",
    "            img = transformed['image']\n",
    "            bboxes = transformed['bboxes']\n",
    "            \n",
    "            # Unflattening list transformed['keypoints']\n",
    "            # For example, if we have the following list of keypoints for three objects (each object has two keypoints):\n",
    "            # [obj1_kp1, obj1_kp2, obj2_kp1, obj2_kp2, obj3_kp1, obj3_kp2], where each keypoint is in [x, y]-format\n",
    "            # Then we need to convert it to the following list:\n",
    "            # [[obj1_kp1, obj1_kp2], [obj2_kp1, obj2_kp2], [obj3_kp1, obj3_kp2]]\n",
    "            keypoints_transformed_unflattened = np.reshape(np.array(transformed['keypoints']), (-1,2,2)).tolist()\n",
    "\n",
    "            # Converting transformed keypoints from [x, y]-format to [x,y,visibility]-format by appending original visibilities to transformed coordinates of keypoints\n",
    "            keypoints = []\n",
    "            for o_idx, obj in enumerate(keypoints_transformed_unflattened): # Iterating over objects\n",
    "                obj_keypoints = []\n",
    "                for k_idx, kp in enumerate(obj): # Iterating over keypoints in each object\n",
    "                    # kp - coordinates of keypoint\n",
    "                    # keypoints_original[o_idx][k_idx][2] - original visibility of keypoint\n",
    "                    obj_keypoints.append(kp + [keypoints_original[o_idx][k_idx][2]])\n",
    "                keypoints.append(obj_keypoints)\n",
    "        \n",
    "        else:\n",
    "            img, bboxes, keypoints = img_original, bboxes_original, keypoints_original        \n",
    "        \n",
    "        # Convert everything into a torch tensor        \n",
    "        bboxes = torch.as_tensor(bboxes, dtype=torch.float32)       \n",
    "        target = {}\n",
    "        target[\"boxes\"] = bboxes\n",
    "        target[\"labels\"] = torch.as_tensor([1 for _ in bboxes], dtype=torch.int64) # all objects are glue tubes\n",
    "        target[\"image_id\"] = torch.tensor([idx])\n",
    "        target[\"area\"] = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])\n",
    "        target[\"iscrowd\"] = torch.zeros(len(bboxes), dtype=torch.int64)\n",
    "        target[\"keypoints\"] = torch.as_tensor(keypoints, dtype=torch.float32)        \n",
    "        img = F.to_tensor(img)\n",
    "        \n",
    "        bboxes_original = torch.as_tensor(bboxes_original, dtype=torch.float32)\n",
    "        target_original = {}\n",
    "        target_original[\"boxes\"] = bboxes_original\n",
    "        target_original[\"labels\"] = torch.as_tensor([1 for _ in bboxes_original], dtype=torch.int64) # all objects are glue tubes\n",
    "        target_original[\"image_id\"] = torch.tensor([idx])\n",
    "        target_original[\"area\"] = (bboxes_original[:, 3] - bboxes_original[:, 1]) * (bboxes_original[:, 2] - bboxes_original[:, 0])\n",
    "        target_original[\"iscrowd\"] = torch.zeros(len(bboxes_original), dtype=torch.int64)\n",
    "        target_original[\"keypoints\"] = torch.as_tensor(keypoints_original, dtype=torch.float32)        \n",
    "        img_original = F.to_tensor(img_original)\n",
    "\n",
    "        if self.demo:\n",
    "            return img, target, img_original, target_original\n",
    "        else:\n",
    "            return img, target\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.imgs_files)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4. Visualizing a random item from dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "KEYPOINTS_FOLDER_TRAIN = 'glue_tubes_keypoints_dataset_134imgs/train'\n",
    "dataset = ClassDataset(KEYPOINTS_FOLDER_TRAIN, transform=train_transform(), demo=True)\n",
    "data_loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)\n",
    "\n",
    "iterator = iter(data_loader)\n",
    "batch = next(iterator)\n",
    "\n",
    "print(\"Original targets:\\n\", batch[3], \"\\n\\n\")\n",
    "print(\"Transformed targets:\\n\", batch[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "keypoints_classes_ids2names = {0: 'Head', 1: 'Tail'}\n",
    "\n",
    "def visualize(image, bboxes, keypoints, image_original=None, bboxes_original=None, keypoints_original=None):\n",
    "    fontsize = 18\n",
    "\n",
    "    for bbox in bboxes:\n",
    "        start_point = (bbox[0], bbox[1])\n",
    "        end_point = (bbox[2], bbox[3])\n",
    "        image = cv2.rectangle(image.copy(), start_point, end_point, (0,255,0), 2)\n",
    "    \n",
    "    for kps in keypoints:\n",
    "        for idx, kp in enumerate(kps):\n",
    "            image = cv2.circle(image.copy(), tuple(kp), 5, (255,0,0), 10)\n",
    "            image = cv2.putText(image.copy(), \" \" + keypoints_classes_ids2names[idx], tuple(kp), cv2.FONT_HERSHEY_SIMPLEX, 2, (255,0,0), 3, cv2.LINE_AA)\n",
    "\n",
    "    if image_original is None and keypoints_original is None:\n",
    "        plt.figure(figsize=(40,40))\n",
    "        plt.imshow(image)\n",
    "\n",
    "    else:\n",
    "        for bbox in bboxes_original:\n",
    "            start_point = (bbox[0], bbox[1])\n",
    "            end_point = (bbox[2], bbox[3])\n",
    "            image_original = cv2.rectangle(image_original.copy(), start_point, end_point, (0,255,0), 2)\n",
    "        \n",
    "        for kps in keypoints_original:\n",
    "            for idx, kp in enumerate(kps):\n",
    "                image_original = cv2.circle(image_original, tuple(kp), 5, (255,0,0), 10)\n",
    "                image_original = cv2.putText(image_original, \" \" + keypoints_classes_ids2names[idx], tuple(kp), cv2.FONT_HERSHEY_SIMPLEX, 2, (255,0,0), 3, cv2.LINE_AA)\n",
    "\n",
    "        f, ax = plt.subplots(1, 2, figsize=(40, 20))\n",
    "\n",
    "        ax[0].imshow(image_original)\n",
    "        ax[0].set_title('Original image', fontsize=fontsize)\n",
    "\n",
    "        ax[1].imshow(image)\n",
    "        ax[1].set_title('Transformed image', fontsize=fontsize)\n",
    "        \n",
    "image = (batch[0][0].permute(1,2,0).numpy() * 255).astype(np.uint8)\n",
    "bboxes = batch[1][0]['boxes'].detach().cpu().numpy().astype(np.int32).tolist()\n",
    "\n",
    "keypoints = []\n",
    "for kps in batch[1][0]['keypoints'].detach().cpu().numpy().astype(np.int32).tolist():\n",
    "    keypoints.append([kp[:2] for kp in kps])\n",
    "\n",
    "image_original = (batch[2][0].permute(1,2,0).numpy() * 255).astype(np.uint8)\n",
    "bboxes_original = batch[3][0]['boxes'].detach().cpu().numpy().astype(np.int32).tolist()\n",
    "\n",
    "keypoints_original = []\n",
    "for kps in batch[3][0]['keypoints'].detach().cpu().numpy().astype(np.int32).tolist():\n",
    "    keypoints_original.append([kp[:2] for kp in kps])\n",
    "\n",
    "visualize(image, bboxes, keypoints, image_original, bboxes_original, keypoints_original)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5. Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_model(num_keypoints, weights_path=None):\n",
    "    \n",
    "    anchor_generator = AnchorGenerator(sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0))\n",
    "    model = torchvision.models.detection.keypointrcnn_resnet50_fpn(pretrained=False,\n",
    "                                                                   pretrained_backbone=True,\n",
    "                                                                   num_keypoints=num_keypoints,\n",
    "                                                                   num_classes = 2, # Background is the first class, object is the second class\n",
    "                                                                   rpn_anchor_generator=anchor_generator)\n",
    "\n",
    "    if weights_path:\n",
    "        state_dict = torch.load(weights_path)\n",
    "        model.load_state_dict(state_dict)        \n",
    "        \n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "\n",
    "KEYPOINTS_FOLDER_TRAIN = 'glue_tubes_keypoints_dataset_134imgs/train'\n",
    "KEYPOINTS_FOLDER_TEST = 'glue_tubes_keypoints_dataset_134imgs/test'\n",
    "\n",
    "dataset_train = ClassDataset(KEYPOINTS_FOLDER_TRAIN, transform=train_transform(), demo=False)\n",
    "dataset_test = ClassDataset(KEYPOINTS_FOLDER_TEST, transform=None, demo=False)\n",
    "\n",
    "data_loader_train = DataLoader(dataset_train, batch_size=3, shuffle=True, collate_fn=collate_fn)\n",
    "data_loader_test = DataLoader(dataset_test, batch_size=1, shuffle=False, collate_fn=collate_fn)\n",
    "\n",
    "model = get_model(num_keypoints = 2)\n",
    "model.to(device)\n",
    "\n",
    "params = [p for p in model.parameters() if p.requires_grad]\n",
    "optimizer = torch.optim.SGD(params, lr=0.001, momentum=0.9, weight_decay=0.0005)\n",
    "lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.3)\n",
    "num_epochs = 5\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    train_one_epoch(model, optimizer, data_loader_train, device, epoch, print_freq=1000)\n",
    "    lr_scheduler.step()\n",
    "    evaluate(model, data_loader_test, device)\n",
    "    \n",
    "# Save model weights after training\n",
    "torch.save(model.state_dict(), 'keypointsrcnn_weights.pth')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 6. Visualizing model predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "iterator = iter(data_loader_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "images, targets = next(iterator)\n",
    "images = list(image.to(device) for image in images)\n",
    "\n",
    "with torch.no_grad():\n",
    "    model.to(device)\n",
    "    model.eval()\n",
    "    output = model(images)\n",
    "\n",
    "print(\"Predictions: \\n\", output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image = (images[0].permute(1,2,0).detach().cpu().numpy() * 255).astype(np.uint8)\n",
    "scores = output[0]['scores'].detach().cpu().numpy()\n",
    "\n",
    "high_scores_idxs = np.where(scores > 0.7)[0].tolist() # Indexes of boxes with scores > 0.7\n",
    "post_nms_idxs = torchvision.ops.nms(output[0]['boxes'][high_scores_idxs], output[0]['scores'][high_scores_idxs], 0.3).cpu().numpy() # Indexes of boxes left after applying NMS (iou_threshold=0.3)\n",
    "\n",
    "# Below, in output[0]['keypoints'][high_scores_idxs][post_nms_idxs] and output[0]['boxes'][high_scores_idxs][post_nms_idxs]\n",
    "# Firstly, we choose only those objects, which have score above predefined threshold. This is done with choosing elements with [high_scores_idxs] indexes\n",
    "# Secondly, we choose only those objects, which are left after NMS is applied. This is done with choosing elements with [post_nms_idxs] indexes\n",
    "\n",
    "keypoints = []\n",
    "for kps in output[0]['keypoints'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy():\n",
    "    keypoints.append([list(map(int, kp[:2])) for kp in kps])\n",
    "\n",
    "bboxes = []\n",
    "for bbox in output[0]['boxes'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy():\n",
    "    bboxes.append(list(map(int, bbox.tolist())))\n",
    "    \n",
    "visualize(image, bboxes, keypoints)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
